{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "import re\n",
    "import IPython, graphviz\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
    "from sklearn.tree import export_graphviz\n",
    "from copy import deepcopy\n",
    "\n",
    "from tree_evasion.core import *\n",
    "\n",
    "SEED = 41\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_all_ = ['Tree']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree\n",
    "\n",
    "> How to represent a trained Random Forest Classifier (scikit-learn) model?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Tree:\n",
    "    def __init__(self, tree):\n",
    "        self.tree   = tree\n",
    "        \n",
    "    def is_leaf(self, index):\n",
    "        return self.tree.feature[index] < 0\n",
    "    \n",
    "    def get_left_child_index(self, index):\n",
    "        return self.tree.children_left[index]\n",
    "        \n",
    "    def get_right_child_index(self, index):\n",
    "        return self.tree.children_right[index]\n",
    "        \n",
    "    def predicate(self, index):\n",
    "        return (self.tree.feature[index], self.tree.threshold[index])\n",
    "    \n",
    "    def prediction(self, index):\n",
    "        class_prob = self.tree.value[index].ravel()\n",
    "        class_prob_sum = np.sum(class_prob)\n",
    "        return class_prob[1] / class_prob_sum\n",
    "    \n",
    "    def constraints(self, x):\n",
    "        instance = x.reshape(1, -1).astype(np.float32)\n",
    "        path = self.tree.decision_path(instance).toarray().ravel()\n",
    "        nodes = np.where(path == 1)[0][:-1]\n",
    "        \n",
    "        c = {}\n",
    "        \n",
    "        for nidx in nodes:\n",
    "            fnode = self.tree.feature[nidx]\n",
    "            \n",
    "            instance_feat = instance.ravel()[fnode]\n",
    "            \n",
    "            if instance_feat <= self.tree.threshold[nidx]:\n",
    "                if fnode not in c:\n",
    "                    c[fnode] = [-np.inf, self.tree.threshold[nidx]]\n",
    "                else:\n",
    "                    if self.tree.threshold[nidx] < c[fnode][1]:\n",
    "                        c[fnode][1] = self.tree.threshold[nidx]\n",
    "            else:\n",
    "                if fnode not in c:\n",
    "                    c[fnode] = [self.tree.threshold[nidx], np.inf]\n",
    "                else:\n",
    "                    if self.tree.threshold[nidx] > c[fnode][0]:\n",
    "                        c[fnode][0] = self.tree.threshold[nidx]\n",
    "                    \n",
    "        return c"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
